# ---> Neuro-inspired Sparse Training (NIST) manuscript <---
# Authors: Mohsen Kamelian Rad, Sotiris Moschoyiannis, Lu Yin, Roman Bauer---
# This script by default runs NIST on DeiT_small
# To change the sparsity ratio, pruning epoch, etc., edit the arguments of the "train_deit_timm" call at the end of this file.

import os
import glob
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchvision
import timm
import time
import numpy as np
from tqdm import tqdm
from nist2 import nist_layer

# Release cached CUDA memory left over from earlier work in this process.
# Guarded so the script can still start on a machine without a usable GPU
# (train_deit_timm itself falls back to the CPU in that case).
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


#from setup import *
from load_dataset import get_dataset

# Name of the compute device, printed once for the run log.
dev = torch.cuda.get_device_name(device=None) if torch.cuda.is_available() else "cpu"
print(dev)

# Wall-clock timer; it starts before data loading, so the time reported at
# the end of the script covers dataset setup as well as training.
start_time = time.time()
print("timing started...")

# Training image folders (a list, so several folders can be passed) and the
# validation folder.
train_dirs = [
    "/mnt/fast/nobackup/users/mk02339/nist/imagenet100/train_all"
]
val_dir = "/mnt/fast/nobackup/users/mk02339/nist/imagenet100/val.X"
train_loader, val_loader = get_dataset(
    train_dirs,
    val_dir,
    batch_size=128,
    num_workers=8,
    shuffle_train=True,
)

# Optional sanity check of one batch (disabled; uncomment to use):
# images, labels = next(iter(train_loader))
# print("Image batch shape:", images.shape)  # e.g., torch.Size([8, 3, 224, 224])
# print("Label batch:", labels)
# print("Unique classes in labels:", torch.unique(labels))
#
# # Get class names
# class_names = train_loader.dataset.datasets[0].classes \
#               if hasattr(train_loader.dataset, 'datasets') else train_loader.dataset.classes
# print("Number of classes:", len(class_names))
# print("Sample classes:", class_names[:10])

# Per-epoch accuracy histories (in percent); train_deit_timm appends to these
# and saves them to .npy files after every epoch.
training_accs = []
validation_accs = []

def replace_linear_with_nist(module):
    """Recursively swap every ``nn.Linear`` beneath ``module`` for a ``nist_layer``.

    The replacement is done in place via ``setattr`` on the parent module.
    Each new layer is built only from the old layer's ``in_features`` and
    ``out_features``; the old layer's parameters are not copied over.

    ``neuroseed_factor`` is chosen from the attribute name of the layer:
      * ``'head'`` (dense classifier head)  -> ``round(out_features)``
      * anything else, including ``'fc1'``  -> ``round(out_features / 2)``

    Args:
        module: Root ``nn.Module`` whose descendants are rewritten.

    Returns:
        None. ``module`` is modified in place.
    """
    for child_name, submodule in list(module.named_children()):
        if not isinstance(submodule, nn.Linear):
            # Container or non-linear layer: keep searching below it.
            replace_linear_with_nist(submodule)
            continue

        n_in = submodule.in_features
        n_out = submodule.out_features
        # Only the classifier head is seeded at full width; every other
        # linear layer uses half of its output width.
        if child_name == 'head':
            seed_factor = round(n_out)
        else:
            seed_factor = round(n_out / 2)
        setattr(module, child_name, nist_layer(n_in, n_out, neuroseed_factor=seed_factor))
def magnitude_prune(model, sparsity_target = 0.5):
    """Apply one-shot global magnitude pruning to every ``nist_layer`` in ``model``.

    The absolute weights of all ``nist_layer`` modules are pooled, and the
    k-th smallest magnitude (``k = int(sparsity_target * total_weights)``)
    becomes a single global threshold. Each layer then receives a mask that
    keeps only weights whose magnitude is strictly greater than that
    threshold, handed over through ``layer.apply_pruning_mask(mask)``.
    Weights tied with the threshold are pruned as well, so the resulting
    sparsity can slightly exceed the target.

    Args:
        model: Object exposing ``modules()`` (normally an ``nn.Module``).
        sparsity_target: Fraction of pooled weights to prune, in ``[0, 1]``.

    Returns:
        None. Masks are applied to the layers in place.

    Raises:
        ValueError: If ``sparsity_target`` lies outside ``[0, 1]``.
    """
    if not 0.0 <= sparsity_target <= 1.0:
        raise ValueError(
            f"sparsity_target must be within [0, 1], got {sparsity_target}"
        )

    prunable_layers = [m for m in model.modules() if isinstance(m, nist_layer)]
    if not prunable_layers:
        # torch.cat on an empty list raises an unhelpful error; there is
        # simply nothing to prune in this case.
        print("No nist_layer modules found; nothing to prune.")
        return

    # Gather all weight magnitudes into one flat tensor.
    all_weights = torch.cat(
        [layer.weight.detach().abs().reshape(-1) for layer in prunable_layers]
    )

    # Determine the threshold at the target sparsity.
    total = all_weights.numel()
    k = int(sparsity_target * total)  # total number of final pruned weights
    already_zero = total - int(torch.count_nonzero(all_weights))
    # Clamp at zero: if more weights are already zero than the target asks
    # for, no additional weights are removed.
    print(f"Pruning {max(k - already_zero, 0)} more parameters")
    if k == 0:
        return
    threshold, _ = torch.kthvalue(all_weights, k)
    print("Global pruning threshold: ", threshold)

    # Apply the new mask to each layer.
    for layer in prunable_layers:
        new_mask = (layer.weight.detach().abs() > threshold).float()
        layer.apply_pruning_mask(new_mask)
def train_deit_timm(
    epochs=300,
    lr=1e-3,
    warmup_epochs = 5,
    device=None,
    use_amp=True,
    #model_name="deit_base_distilled_patch16_224",
    model_name="deit_small_patch16_224",
    pretrained=False,
    prune_epoch=2,
    target_sparsity= 0.5
):
    """Train a timm model with NIST layers and prune it once by magnitude.

    The model is created with 100 output classes, all of its ``nn.Linear``
    layers are replaced through ``replace_linear_with_nist``, and it is then
    trained with AdamW, a timm cosine learning-rate schedule with warm-up,
    Mixup/CutMix and (on CUDA) automatic mixed precision. After the training
    pass of epoch ``prune_epoch`` the model is pruned with
    ``magnitude_prune``; validation runs after every epoch.

    The function reads the module-level ``train_loader`` and ``val_loader``
    and appends per-epoch accuracies (percent) to the module-level
    ``training_accs`` / ``validation_accs`` lists, which are re-saved to
    ``train_accs_95p.npy`` / ``val_accs_95p.npy`` after every epoch.

    Args:
        epochs: Number of training epochs (epochs are counted from 1).
        lr: Peak learning rate for AdamW; warm-up starts at ``lr / 25``.
        warmup_epochs: Number of warm-up epochs for the scheduler.
        device: Device string or ``torch.device``. Defaults to CUDA when
            available, otherwise CPU.
        use_amp: Enable mixed precision; only takes effect on CUDA devices.
        model_name: timm model identifier.
        pretrained: Passed through to ``timm.create_model``.
        prune_epoch: Epoch after whose training pass pruning is applied.
        target_sparsity: Sparsity fraction passed to ``magnitude_prune``.

    Returns:
        The trained (and pruned) model.
    """
    from timm.data import Mixup
    from timm.scheduler.cosine_lr import CosineLRScheduler

    # Normalise to torch.device so both strings and device objects work.
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    print(device)
    amp_enabled = use_amp and device.type == "cuda"

    model = timm.create_model(model_name, pretrained=pretrained, num_classes=100)
    replace_linear_with_nist(model)
    model.to(device)

    # Smoothing stays at 0 here because Mixup below already produces
    # smoothed soft targets for the training batches.
    criterion = nn.CrossEntropyLoss(label_smoothing=0.0)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05, betas=(0.9, 0.999))
    lr_scheduler = CosineLRScheduler(
        optimizer,
        t_initial=epochs - warmup_epochs,
        lr_min=1e-5,                     # 1e-5 originally
        warmup_lr_init=lr / 25,
        warmup_t=warmup_epochs,
        cycle_decay=0.1,
        cycle_limit=1
    )
    mixup_fn = Mixup(
        mixup_alpha=0.8, cutmix_alpha=1.0,
        label_smoothing=0.1, num_classes=100
    )
    scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)

    for epoch in range(1, epochs + 1):
        # ---- Training pass ----
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        # mininterval=60 slows down tqdm refreshes so batch-job logs stay small.
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} Train", ncols=120, mininterval=60)

        for imgs, labels in pbar:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # timm's Mixup asserts an even batch size, so an odd batch
            # (typically the last one of an epoch) loses one sample instead
            # of aborting the run.
            if imgs.size(0) % 2 == 1:
                imgs, labels = imgs[:-1], labels[:-1]
                if imgs.size(0) == 0:
                    continue

            # Apply Mixup/CutMix; labels become soft targets.
            imgs, labels = mixup_fn(imgs, labels)

            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=amp_enabled):
                outputs = model(imgs)
                logits = outputs if isinstance(outputs, torch.Tensor) else outputs.logits
                loss = criterion(logits, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item() * imgs.size(0)
            preds = logits.argmax(dim=1)
            # With soft targets the dominant class is used as the reference,
            # so the training accuracy is only a noisy indicator.
            targets = labels.argmax(dim=1) if labels.ndim > 1 else labels
            correct += (preds == targets).sum().item()
            total += labels.size(0)

        lr_scheduler.step(epoch)
        train_loss = total_loss / total
        train_acc = 100.0 * correct / total
        print(f"Epoch {epoch} Train loss {train_loss:.4f} acc {train_acc:.2f}%")
        training_accs.append(train_acc)
        # NOTE: the "95p" file tag is hard-coded and does not follow target_sparsity.
        np.save('train_accs_95p.npy', training_accs)

        # ---- One-shot pruning (before this epoch's validation) ----
        if epoch == prune_epoch:
            print(f"Applying magnitude pruning at epoch {epoch}...")
            magnitude_prune(model, target_sparsity)

        # ---- Validation pass ----
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for imgs, labels in tqdm(val_loader, desc="Validation", ncols=120, mininterval=60):
                imgs = imgs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                with torch.cuda.amp.autocast(enabled=amp_enabled):
                    outputs = model(imgs)
                    logits = outputs if isinstance(outputs, torch.Tensor) else outputs.logits
                    loss = criterion(logits, labels)

                val_loss += loss.item() * imgs.size(0)
                preds = logits.argmax(dim=1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

        val_loss = val_loss / val_total
        val_acc = 100.0 * val_correct / val_total
        print(f"Epoch {epoch} Val loss {val_loss:.4f} acc {val_acc:.2f}%")
        validation_accs.append(val_acc)
        np.save('val_accs_95p.npy', validation_accs)
        torch.cuda.empty_cache()
        # Checkpointing of the best model is currently not implemented here.

    return model

# ---------------------------
# 5) Run script: launch training and report the elapsed time
# ---------------------------
# A command-line interface was sketched here but is not enabled:
# if __name__ == "__main__":
#     import argparse
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--epochs", type=int)
#     parser.add_argument("--lr", type=float, default=5e-5)
#     args = parser.parse_args()

# Settings for this run: 400 epochs, pruning to 95% sparsity after epoch 200.
run_config = dict(
    epochs=400,
    lr=1e-3,
    prune_epoch=200,
    target_sparsity=0.95,
)
train_deit_timm(**run_config)

# Break the seconds elapsed since `start_time` into hours / minutes / seconds.
end_time = time.time()
total_seconds = end_time - start_time
hours, remainder = divmod(total_seconds, 3600)
hours = int(hours)
minutes = int(remainder // 60)
seconds = int(total_seconds % 60)
print(f"Training time: {hours} hours, {minutes} Minutes, and {seconds} Seconds.")



